{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.002482Z",
     "start_time": "2025-05-11T12:01:58.339772Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import collections\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ],
   "id": "6260919fc640c59b",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.139721Z",
     "start_time": "2025-05-11T12:02:01.125586Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#@save\n",
    "class Seq2SeqEncoder(d2l.Encoder):\n",
    "    \"\"\"GRU encoder for sequence-to-sequence learning.\n",
    "\n",
    "    Token indices are embedded and then passed through a (multi-layer) GRU.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 dropout=0, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        # Embedding layer: token index -> vector of length embed_size\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.GRU(input_size=embed_size, hidden_size=num_hiddens,\n",
    "                          num_layers=num_layers, dropout=dropout)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        # After embedding: (batch_size, num_steps, embed_size)\n",
    "        embedded = self.embedding(X)\n",
    "        # The recurrent layer expects the time-step axis first\n",
    "        steps_first = embedded.permute(1, 0, 2)\n",
    "        # No initial state is passed, so it defaults to zeros\n",
    "        # outputs: (num_steps, batch_size, num_hiddens)\n",
    "        # final_state: (num_layers, batch_size, num_hiddens)\n",
    "        outputs, final_state = self.rnn(steps_first)\n",
    "        return outputs, final_state"
   ],
   "id": "ca78b98142048f19",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.186652Z",
     "start_time": "2025-05-11T12:02:01.156646Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Smoke test: a 2-layer encoder on a batch of 4 sequences with 7 time steps\n",
    "encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
    "                         num_layers=2)\n",
    "encoder.eval()\n",
    "X = torch.zeros((4, 7), dtype=torch.long)\n",
    "output, state = encoder(X)\n",
    "# Expected: (num_steps, batch_size, num_hiddens) = (7, 4, 16)\n",
    "output.shape"
   ],
   "id": "d2e23abd24bc366d",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 4, 16])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.218120Z",
     "start_time": "2025-05-11T12:02:01.202373Z"
    }
   },
   "cell_type": "code",
   "source": "# Expected: (num_layers, batch_size, num_hiddens) = (2, 4, 16)\nstate.shape",
   "id": "e2a723e1c384e906",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 16])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.264883Z",
     "start_time": "2025-05-11T12:02:01.250532Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class Seq2SeqDecoder(d2l.Decoder):\n",
    "    \"\"\"GRU decoder for sequence-to-sequence learning.\n",
    "\n",
    "    At every time step the embedded input token is concatenated with a\n",
    "    context vector taken from the last layer of the incoming hidden state.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,\n",
    "                 dropout=0, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        # The GRU input is [embedding ; context], hence embed_size + num_hiddens\n",
    "        self.rnn = nn.GRU(input_size=embed_size + num_hiddens,\n",
    "                          hidden_size=num_hiddens, num_layers=num_layers,\n",
    "                          dropout=dropout)\n",
    "        self.dense = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def init_state(self, enc_outputs, *args):\n",
    "        # Element 1 of the encoder result is its final hidden state\n",
    "        return enc_outputs[1]\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        # After embedding: (batch_size, num_steps, embed_size); move time first\n",
    "        embedded = self.embedding(X).permute(1, 0, 2)\n",
    "        num_steps = embedded.shape[0]\n",
    "        # Repeat the last-layer hidden state so every time step sees the context\n",
    "        context = state[-1].repeat(num_steps, 1, 1)\n",
    "        rnn_input = torch.cat((embedded, context), dim=2)\n",
    "        rnn_output, state = self.rnn(rnn_input, state)\n",
    "        # Project onto the vocabulary and put the batch axis back in front\n",
    "        logits = self.dense(rnn_output).permute(1, 0, 2)\n",
    "        # logits: (batch_size, num_steps, vocab_size)\n",
    "        # state: (num_layers, batch_size, num_hiddens)\n",
    "        return logits, state"
   ],
   "id": "c9ea464d8a77899",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.295307Z",
     "start_time": "2025-05-11T12:02:01.281780Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Smoke test: decode the same (4, 7) batch X, starting from the encoder's state\n",
    "decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16,\n",
    "                         num_layers=2)\n",
    "decoder.eval()\n",
    "state = decoder.init_state(encoder(X))\n",
    "output, state = decoder(X, state)\n",
    "# Expected: output (batch_size, num_steps, vocab_size) = (4, 7, 10),\n",
    "# state (num_layers, batch_size, num_hiddens) = (2, 4, 16)\n",
    "output.shape, state.shape"
   ],
   "id": "51b5c31f0714b72c",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.326370Z",
     "start_time": "2025-05-11T12:02:01.310603Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#@save\n",
    "def sequence_mask(X, valid_len, value=0):\n",
    "    \"\"\"Overwrite the entries of X beyond each sequence's valid length.\n",
    "\n",
    "    Positions along axis 1 whose index is >= valid_len[i] are set to `value`\n",
    "    for row i. X is modified in place and also returned.\n",
    "    \"\"\"\n",
    "    num_steps = X.size(1)\n",
    "    positions = torch.arange(num_steps, dtype=torch.float32, device=X.device)\n",
    "    # Compare a (1, num_steps) row of positions with a column of valid lengths\n",
    "    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)\n",
    "    X[~keep] = value\n",
    "    return X\n",
    "\n",
    "# Keep the first 1 and the first 2 entries of the two rows; the rest become 0\n",
    "X = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
    "sequence_mask(X, torch.tensor([1, 2]))"
   ],
   "id": "ff752780b4395dbf",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0, 0],\n",
       "        [4, 5, 0]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.403946Z",
     "start_time": "2025-05-11T12:02:01.390148Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Masking acts on axis 1 of a 3-D tensor as well; masked steps are filled with -1\n",
    "X = torch.ones(2, 3, 4)\n",
    "sequence_mask(X, torch.tensor([1, 2]), value=-1)"
   ],
   "id": "4faf2378f0a28378",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  1.,  1.,  1.],\n",
       "         [-1., -1., -1., -1.],\n",
       "         [-1., -1., -1., -1.]],\n",
       "\n",
       "        [[ 1.,  1.,  1.,  1.],\n",
       "         [ 1.,  1.,  1.,  1.],\n",
       "         [-1., -1., -1., -1.]]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.481515Z",
     "start_time": "2025-05-11T12:02:01.467298Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#@save\n",
    "class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n",
    "    \"\"\"Softmax cross-entropy loss that ignores padded positions.\"\"\"\n",
    "    # pred: (batch_size, num_steps, vocab_size)\n",
    "    # label: (batch_size, num_steps)\n",
    "    # valid_len: (batch_size,)\n",
    "    def forward(self, pred, label, valid_len):\n",
    "        # Per-token losses are needed so that padding can be zeroed out\n",
    "        self.reduction = 'none'\n",
    "        # The parent loss expects the class axis second: (batch, vocab, steps)\n",
    "        per_token_loss = super().forward(pred.permute(0, 2, 1), label)\n",
    "        # 1 for positions inside the valid length, 0 for padding\n",
    "        token_weights = sequence_mask(torch.ones_like(label), valid_len)\n",
    "        # Average over the time-step axis, giving one loss value per sequence\n",
    "        return (per_token_loss * token_weights).mean(dim=1)"
   ],
   "id": "24241ffda113c089",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:01.527669Z",
     "start_time": "2025-05-11T12:02:01.512067Z"
    }
   },
   "cell_type": "code",
   "source": [
    "#@save\n",
    "def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n",
    "    \"\"\"Train a sequence-to-sequence model with teacher forcing.\n",
    "\n",
    "    Args:\n",
    "        net: model called as ``net(X, dec_input, X_valid_len)``; its first\n",
    "            return value is used as the prediction.\n",
    "        data_iter: iterable yielding ``(X, X_valid_len, Y, Y_valid_len)``.\n",
    "        lr: learning rate passed to Adam.\n",
    "        num_epochs: number of passes over ``data_iter``.\n",
    "        tgt_vocab: target vocabulary; only ``tgt_vocab['<bos>']`` is read.\n",
    "        device: device the model and every batch are moved to.\n",
    "\n",
    "    Prints the last epoch's loss per token and its throughput at the end.\n",
    "    \"\"\"\n",
    "    def xavier_init_weights(m):\n",
    "        if type(m) == nn.Linear:\n",
    "            nn.init.xavier_uniform_(m.weight)\n",
    "        if type(m) == nn.GRU:\n",
    "            # Iterate through the public API rather than the private\n",
    "            # _flat_weights_names / _parameters attributes of nn.GRU.\n",
    "            # Only weight matrices are re-initialised; biases are left alone.\n",
    "            for name, param in m.named_parameters(recurse=False):\n",
    "                if \"weight\" in name:\n",
    "                    nn.init.xavier_uniform_(param)\n",
    "\n",
    "    net.apply(xavier_init_weights)\n",
    "    net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    loss = MaskedSoftmaxCELoss()\n",
    "    net.train()\n",
    "    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
    "                            xlim=[10, num_epochs])\n",
    "    for epoch in range(num_epochs):\n",
    "        timer = d2l.Timer()\n",
    "        metric = d2l.Accumulator(2)  # sum of training loss, number of tokens\n",
    "        for batch in data_iter:\n",
    "            optimizer.zero_grad()\n",
    "            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n",
    "            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],\n",
    "                               device=device).reshape(-1, 1)\n",
    "            # Teacher forcing: the decoder sees <bos> followed by the gold\n",
    "            # target shifted right by one position\n",
    "            dec_input = torch.cat([bos, Y[:, :-1]], 1)\n",
    "            Y_hat, _ = net(X, dec_input, X_valid_len)\n",
    "            batch_loss = loss(Y_hat, Y, Y_valid_len)\n",
    "            # Reduce the per-sequence losses to a scalar before backprop\n",
    "            batch_loss.sum().backward()\n",
    "            d2l.grad_clipping(net, 1)\n",
    "            num_tokens = Y_valid_len.sum()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                metric.add(batch_loss.sum(), num_tokens)\n",
    "        # Plot every 10th epoch only\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            animator.add(epoch + 1, (metric[0] / metric[1],))\n",
    "    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '\n",
    "        f'tokens/sec on {str(device)}')"
   ],
   "id": "ac794ae6829e8964",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-11T12:02:02.022415Z",
     "start_time": "2025-05-11T12:02:01.544754Z"
    }
   },
   "cell_type": "code",
   "source": [
    "embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1\n",
    "batch_size, num_steps = 64, 10\n",
    "lr, num_epochs, device = 0.005, 300, d2l.try_gpu()\n",
    "\n",
    "\n",
    "def read_data_nmt_utf8():\n",
    "    \"\"\"Load the English-French dataset, decoding the file explicitly as UTF-8.\"\"\"\n",
    "    import os\n",
    "    data_dir = d2l.download_extract('fra-eng')\n",
    "    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:\n",
    "        return f.read()\n",
    "\n",
    "\n",
    "# d2l.read_data_nmt opens fra.txt without an encoding, so it falls back to the\n",
    "# locale codec and raises UnicodeDecodeError where that codec is not UTF-8\n",
    "# (e.g. GBK). load_data_nmt looks the reader up as a module global, so\n",
    "# rebinding it on the module makes the loader use the UTF-8 version.\n",
    "d2l.read_data_nmt = read_data_nmt_utf8\n",
    "\n",
    "train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)\n",
    "encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,\n",
    "                        dropout)\n",
    "decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,\n",
    "                        dropout)\n",
    "net = d2l.EncoderDecoder(encoder, decoder)\n",
    "train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)"
   ],
   "id": "ecf7e1beb65a03ab",
   "outputs": [
    {
     "ename": "UnicodeDecodeError",
     "evalue": "'gbk' codec can't decode byte 0xaf in position 33: illegal multibyte sequence",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mUnicodeDecodeError\u001B[0m                        Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[11], line 5\u001B[0m\n\u001B[0;32m      2\u001B[0m batch_size, num_steps \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m64\u001B[39m, \u001B[38;5;241m10\u001B[39m\n\u001B[0;32m      3\u001B[0m lr, num_epochs, device \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0.005\u001B[39m, \u001B[38;5;241m300\u001B[39m, d2l\u001B[38;5;241m.\u001B[39mtry_gpu()\n\u001B[1;32m----> 5\u001B[0m train_iter, src_vocab, tgt_vocab \u001B[38;5;241m=\u001B[39m \u001B[43md2l\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mload_data_nmt\u001B[49m\u001B[43m(\u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnum_steps\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m      6\u001B[0m encoder \u001B[38;5;241m=\u001B[39m Seq2SeqEncoder(\u001B[38;5;28mlen\u001B[39m(src_vocab), embed_size, num_hiddens, num_layers,\n\u001B[0;32m      7\u001B[0m                         dropout)\n\u001B[0;32m      8\u001B[0m decoder \u001B[38;5;241m=\u001B[39m Seq2SeqDecoder(\u001B[38;5;28mlen\u001B[39m(tgt_vocab), embed_size, num_hiddens, num_layers,\n\u001B[0;32m      9\u001B[0m                         dropout)\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\d2l\\torch.py:931\u001B[0m, in \u001B[0;36mload_data_nmt\u001B[1;34m(batch_size, num_steps, num_examples)\u001B[0m\n\u001B[0;32m    929\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21mload_data_nmt\u001B[39m(batch_size, num_steps, num_examples\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m600\u001B[39m):\n\u001B[0;32m    930\u001B[0m \u001B[38;5;250m    \u001B[39m\u001B[38;5;124;03m\"\"\"Return the iterator and the vocabularies of the translation dataset.\"\"\"\u001B[39;00m\n\u001B[1;32m--> 931\u001B[0m     text \u001B[38;5;241m=\u001B[39m preprocess_nmt(\u001B[43mread_data_nmt\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[0;32m    932\u001B[0m     source, target \u001B[38;5;241m=\u001B[39m tokenize_nmt(text, num_examples)\n\u001B[0;32m    933\u001B[0m     src_vocab \u001B[38;5;241m=\u001B[39m d2l\u001B[38;5;241m.\u001B[39mVocab(source, min_freq\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2\u001B[39m,\n\u001B[0;32m    934\u001B[0m                           reserved_tokens\u001B[38;5;241m=\u001B[39m[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m<pad>\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m<bos>\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m<eos>\u001B[39m\u001B[38;5;124m'\u001B[39m])\n",
      "File \u001B[1;32m~\\anaconda3\\envs\\d2l\\lib\\site-packages\\d2l\\torch.py:862\u001B[0m, in \u001B[0;36mread_data_nmt\u001B[1;34m()\u001B[0m\n\u001B[0;32m    860\u001B[0m data_dir \u001B[38;5;241m=\u001B[39m d2l\u001B[38;5;241m.\u001B[39mdownload_extract(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfra-eng\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m    861\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mjoin(data_dir, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfra.txt\u001B[39m\u001B[38;5;124m'\u001B[39m), \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[1;32m--> 862\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mread\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[1;31mUnicodeDecodeError\u001B[0m: 'gbk' codec can't decode byte 0xaf in position 33: illegal multibyte sequence"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "#@save\n",
    "def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,\n",
    "                    device, save_attention_weights=False):\n",
    "    \"\"\"Greedily decode a translation for a single source sentence.\n",
    "\n",
    "    Returns a pair: the predicted tokens joined by spaces, and the list of\n",
    "    attention weights collected per step (empty unless\n",
    "    save_attention_weights is true).\n",
    "    \"\"\"\n",
    "    # Switch to evaluation mode for prediction\n",
    "    net.eval()\n",
    "    words = src_sentence.lower().split(' ')\n",
    "    src_ids = src_vocab[words] + [src_vocab['<eos>']]\n",
    "    # The valid length is taken before padding/truncation\n",
    "    enc_valid_len = torch.tensor([len(src_ids)], device=device)\n",
    "    padded_ids = d2l.truncate_pad(src_ids, num_steps, src_vocab['<pad>'])\n",
    "    # Add the batch axis\n",
    "    enc_X = torch.tensor(\n",
    "        padded_ids, dtype=torch.long, device=device).unsqueeze(0)\n",
    "    enc_outputs = net.encoder(enc_X, enc_valid_len)\n",
    "    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n",
    "    # Decoder input with the batch axis already in place: shape (1, 1)\n",
    "    dec_X = torch.tensor(\n",
    "        [[tgt_vocab['<bos>']]], dtype=torch.long, device=device)\n",
    "    output_seq = []\n",
    "    attention_weight_seq = []\n",
    "    for _ in range(num_steps):\n",
    "        Y, dec_state = net.decoder(dec_X, dec_state)\n",
    "        # The most likely token becomes the decoder input of the next step\n",
    "        dec_X = Y.argmax(dim=2)\n",
    "        pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n",
    "        # Save attention weights (only meaningful for attention decoders)\n",
    "        if save_attention_weights:\n",
    "            attention_weight_seq.append(net.decoder.attention_weights)\n",
    "        # Stop as soon as the end-of-sequence token is predicted\n",
    "        if pred == tgt_vocab['<eos>']:\n",
    "            break\n",
    "        output_seq.append(pred)\n",
    "    translation = ' '.join(tgt_vocab.to_tokens(output_seq))\n",
    "    return translation, attention_weight_seq"
   ],
   "id": "d1d346fc781c4f27"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def bleu(pred_seq, label_seq, k):  #@save\n",
    "    \"\"\"Compute the BLEU score of a predicted sequence against a reference.\n",
    "\n",
    "    Args:\n",
    "        pred_seq: predicted sentence, tokens separated by single spaces.\n",
    "        label_seq: reference sentence, tokens separated by single spaces.\n",
    "        k: longest n-gram length to match.\n",
    "\n",
    "    Returns:\n",
    "        The brevity penalty multiplied by the n-gram precisions p_n, each\n",
    "        raised to the power 0.5 ** n. Returns 0.0 when the prediction is too\n",
    "        short to contain an n-gram for some n <= k.\n",
    "    \"\"\"\n",
    "    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')\n",
    "    len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
    "    # Brevity penalty: predictions shorter than the reference are penalised\n",
    "    score = math.exp(min(0, 1 - len_label / len_pred))\n",
    "    for n in range(1, k + 1):\n",
    "        num_pred_ngrams = len_pred - n + 1\n",
    "        if num_pred_ngrams <= 0:\n",
    "            # No n-gram of this length exists in the prediction, so its\n",
    "            # precision is zero; this also avoids dividing by zero below\n",
    "            return 0.0\n",
    "        label_subs = collections.Counter(\n",
    "            ' '.join(label_tokens[i: i + n])\n",
    "            for i in range(len_label - n + 1))\n",
    "        num_matches = 0\n",
    "        for i in range(num_pred_ngrams):\n",
    "            ngram = ' '.join(pred_tokens[i: i + n])\n",
    "            # Each reference n-gram can be matched at most once (clipping)\n",
    "            if label_subs[ngram] > 0:\n",
    "                num_matches += 1\n",
    "                label_subs[ngram] -= 1\n",
    "        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))\n",
    "    return score"
   ],
   "id": "b7e79f6636e1be0e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Translate a few English sentences and score each against its French reference\n",
    "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .']\n",
    "fras = ['va !', 'j\\'ai perdu .', 'il est calme .', 'je suis chez moi .']\n",
    "for eng, fra in zip(engs, fras):\n",
    "    # attention_weight_seq stays empty: save_attention_weights keeps its default\n",
    "    translation, attention_weight_seq = predict_seq2seq(\n",
    "        net, eng, src_vocab, tgt_vocab, num_steps, device)\n",
    "    # BLEU is computed with n-grams up to length 2\n",
    "    print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')"
   ],
   "id": "95135bd819979a3b"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.22"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
